import json
import os


def is_target_function(func):
    """Tell whether *func* is the function examined for CVE-2021-45608.

    The match is an exact comparison of the function's name against
    ``SoftwareBus_dispatchNormalEPMsgOut``.

    Returns:
        True on a name match, False otherwise.
    """
    wanted_name = 'SoftwareBus_dispatchNormalEPMsgOut'
    return func.getName() == wanted_name


def find_kmalloc_calls_in_function(func):
    """Collect the call-site addresses of ``__kmalloc`` inside *func*.

    For every function called by *func* whose name is ``__kmalloc``, every
    reference to its entry point is inspected; references whose source
    address lies inside *func*'s body are reported (printed) and collected.

    Returns:
        A list of the matching reference source addresses, in discovery order.
    """
    call_sites = []
    callees = func.getCalledFunctions(monitor)  # noqa
    kmalloc_callees = [callee for callee in callees.iterator() if callee.getName() == '__kmalloc']

    for callee in kmalloc_callees:
        for reference in getReferencesTo(callee.getEntryPoint()):  # noqa
            source = reference.getFromAddress()
            # Only keep references that originate within this function.
            if not func.getBody().contains(source):
                continue
            print('found __kmalloc call @ {}'.format(source))
            call_sites.append(source)
    return call_sites


def get_decompiler():
    """Build and initialise a decompiler bound to the current program.

    A new ``FlatProgramAPI``/``FlatDecompilerAPI`` pair is created on every
    call; the returned decompiler is not disposed here, so releasing it is
    left to the caller.

    Returns:
        The decompiler obtained from the initialised ``FlatDecompilerAPI``.
    """
    program = getCurrentProgram()  # noqa
    program_api = ghidra.program.flatapi.FlatProgramAPI(program, getMonitor())  # noqa
    decompiler_facade = ghidra.app.decompiler.flatapi.FlatDecompilerAPI(program_api)  # noqa
    decompiler_facade.initialize()
    decompiler = decompiler_facade.getDecompiler()
    return decompiler


def get_pcode_mnemonics(block):
    """Count how often each P-code mnemonic occurs in *block* (debug aid).

    Returns:
        A dict mapping mnemonic string -> number of ops in the block that
        carry that mnemonic.
    """
    tally = {}
    for pcode_op in block.getIterator():
        name = pcode_op.getMnemonic()
        tally[name] = tally.get(name, 0) + 1
    return tally


def get_function_blocks(function, timeout=120):
    """Decompile *function* and return the basic blocks of its HighFunction.

    Args:
        function: the Ghidra function to decompile.
        timeout: decompiler timeout passed to ``decompileFunction``
            (defaults to the previously hard-coded 120).

    Returns:
        A list of the HighFunction's basic blocks, or an empty list when the
        decompiler produced no HighFunction (e.g. timeout or decompile error);
        the error message is printed in that case.
    """
    decompiler = get_decompiler()
    try:
        results = decompiler.decompileFunction(function, timeout, getMonitor())  # noqa
        high_function = results.getHighFunction()
        if high_function is None:
            # Previously this crashed with AttributeError on None.
            print('decompilation of {} failed: {}'.format(function.getName(), results.getErrorMessage()))
            return []
        return list(high_function.getBasicBlocks())
    finally:
        # get_decompiler() creates a fresh decompiler per call; release it so
        # the underlying decompiler process is not leaked.
        decompiler.dispose()


def get_block_in_indexes(block):
    """List the indexes of the blocks with an edge into *block*.

    Returns:
        ``getIndex()`` of each incoming neighbour, in edge order
        ``0 .. getInSize() - 1`` (a neighbour appears once per edge).
    """
    predecessor_indexes = []
    for edge_no in range(block.getInSize()):
        predecessor = block.getIn(edge_no)
        predecessor_indexes.append(predecessor.getIndex())
    return predecessor_indexes


def find_blocks_with_kmalloc_call(blocks, kmalloc_calls):
    """Return the blocks that contain at least one of *kmalloc_calls*.

    Args:
        blocks: basic blocks exposing ``contains(address)`` and ``getIndex()``.
        kmalloc_calls: call-site addresses to look for.

    Returns:
        The matching blocks in the order of *blocks*. Each block appears at
        most once even if it contains several call sites. The previous
        version appended it once per call, which duplicated the block's
        parents in ``main`` and in the JSON output.
    """
    blocks_with_kmalloc = [
        block for block in blocks
        if any(block.contains(kmalloc_call) for kmalloc_call in kmalloc_calls)
    ]
    print('kmalloc call(s) in basic blocks: {}'.format([i.getIndex() for i in blocks_with_kmalloc]))

    return blocks_with_kmalloc


def has_less_than_branch(block, mnemonics=('INT_LESS',)):
    """Check whether *block* contains an INT_LESS (or other given) P-code op.

    Args:
        block: basic block whose ``getIterator()`` yields P-code ops.
        mnemonics: iterable of mnemonic strings that count as a match.
            Defaults to ``('INT_LESS',)``, i.e. the original behaviour.
            Callers may widen it, e.g. to also accept ``'INT_LESSEQUAL'``,
            ``'INT_SLESS'`` or ``'INT_SLESSEQUAL'``.

    Returns:
        True if any op in the block has one of *mnemonics*, else False.
    """
    wanted = frozenset(mnemonics)
    # Stream the ops and stop at the first match instead of building a list.
    return any(pcode.getMnemonic() in wanted for pcode in block.getIterator())


def create_result_json(kmalloc_addresses, blocks_with_kmalloc_call, parents, is_vulnerable,
                       output_file='/io/result.json'):
    """Write the analysis results to a JSON file for further use.

    Args:
        kmalloc_addresses: call-site addresses (objects with ``toString()``).
        blocks_with_kmalloc_call: blocks (objects with ``getIndex()``)
            containing those call sites.
        parents: JSON-serialisable list of predecessor block indexes.
        is_vulnerable: verdict to record.
        output_file: destination path; defaults to ``/io/result.json``.

    The file is made world read/writable afterwards (0o666).
    """
    # Build the payload first so a failure here does not truncate an
    # existing result file.
    payload = {
        'kmalloc_addresses': [f.toString() for f in kmalloc_addresses],
        'kmalloc_basic_blocks': [b.getIndex() for b in blocks_with_kmalloc_call],
        'parent_blocks': parents,
        'is_vulnerable': is_vulnerable,
    }
    # 'w' creates/truncates for writing. The former '+w' requested unused
    # read access and is not a portable mode string (CPython 2 rejects modes
    # that do not start with 'r', 'w', 'a' or 'U').
    with open(output_file, 'w') as handle:
        json.dump(payload, handle)
    os.chmod(output_file, 0o666)  # assure access rights to file created inside docker container


def _find_target_function():
    """Walk the program's functions and return the first target match.

    Functions are visited via ``getFirstFunction`` / ``getFunctionAfter``
    until ``is_target_function`` accepts one.

    Returns:
        The matching function, or None when the walk runs out of functions.
    """
    print('Searching for CVE-2021-45608 related function: SoftwareBus_dispatchNormalEPMsgOut...')
    candidate = getFirstFunction()  # noqa
    while candidate is not None and not is_target_function(candidate):
        candidate = getFunctionAfter(candidate)  # noqa
    if candidate is not None:
        print('found {} at {}'.format(candidate.getName(), candidate.getEntryPoint().toString()))
    return candidate


def main():
    """Check the target function for the CVE-2021-45608 fix and write results.

    Steps:
    1. Locate the target function.
    2. Collect its ``__kmalloc`` call sites.
    3. Map those call sites to decompiled basic blocks.
    4. Gather the predecessors of those blocks.
    5. Treat the function as not vulnerable if any predecessor contains an
       INT_LESS op.

    Nothing is written when the function or a kmalloc block cannot be found.
    """
    target = _find_target_function()
    if target is None:
        print('could not find function.')
        return

    kmalloc_addresses = find_kmalloc_calls_in_function(target)

    blocks = get_function_blocks(target)
    blocks_with_kmalloc_call = find_blocks_with_kmalloc_call(blocks, kmalloc_addresses)
    if not blocks_with_kmalloc_call:
        print('could not find any blocks with kmalloc.')
        return

    # Predecessor indexes of every kmalloc block, in block/edge order.
    parents = [index for kmalloc_block in blocks_with_kmalloc_call for index in get_block_in_indexes(kmalloc_block)]
    print('found parent blocks: {}'.format(parents))

    print('check if parent blocks contain the fix...')
    parent_indexes = set(parents)
    parent_blocks = (candidate for candidate in blocks if candidate.getIndex() in parent_indexes)
    # A single guarded predecessor is taken as "fixed"; other paths are not checked.
    is_vulnerable = not any(has_less_than_branch(candidate) for candidate in parent_blocks)

    create_result_json(kmalloc_addresses, blocks_with_kmalloc_call, parents, is_vulnerable)


if __name__ == '__main__':
    main()
